/*
 * Author: tdanford
 * Date: Oct 27, 2008
 */
package edu.mit.csail.cgs.sigma.expression.segmentation.viz;

import java.sql.SQLException;
import java.util.*;
import java.awt.*;
import java.awt.event.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;

import javax.imageio.ImageIO;
import javax.swing.*;

import edu.mit.csail.cgs.datasets.general.Region;
import edu.mit.csail.cgs.datasets.general.StrandedRegion;
import edu.mit.csail.cgs.datasets.species.Genome;
import edu.mit.csail.cgs.sigma.expression.BaseExpressionProperties;
import edu.mit.csail.cgs.sigma.expression.SudeepExpressionProperties;
import edu.mit.csail.cgs.sigma.expression.segmentation.Segment;
import edu.mit.csail.cgs.sigma.expression.transcription.*;
import edu.mit.csail.cgs.sigma.expression.transcription.fitters.TAFit;
import edu.mit.csail.cgs.sigma.expression.transcription.identifiers.TranscriptIdentifier;
import edu.mit.csail.cgs.sigma.expression.transcription.viz.ModelTranscriptCallPainter;
import edu.mit.csail.cgs.sigma.expression.workflow.models.FileInputData;
import edu.mit.csail.cgs.sigma.expression.workflow.models.InputSegmentation;
import edu.mit.csail.cgs.sigma.genes.GeneAnnotationProperties;
import edu.mit.csail.cgs.sigma.viz.GenePainter;
import edu.mit.csail.cgs.utils.Pair;
import edu.mit.csail.cgs.utils.database.DatabaseException;
import edu.mit.csail.cgs.utils.models.Model;
import edu.mit.csail.cgs.viz.colors.Coloring;
import edu.mit.csail.cgs.viz.eye.*;
import edu.mit.csail.cgs.viz.paintable.DoubleBufferedPaintable;
import edu.mit.csail.cgs.viz.paintable.PaintablePanel;
import edu.mit.csail.cgs.viz.paintable.PaintableScale;
import edu.mit.csail.cgs.viz.paintable.VerticalScalePainter;
import edu.mit.csail.cgs.viz.utils.FileChooser;

/**
 * A Swing component for quickly inspecting expression values together with the
 * segments generated by a Segmenter.
 * <p>
 * From top to bottom the panel paints: the probe values (with the fitted segments and a
 * vertical scale overlaid), an optional gene track (shown once a genomic region has been
 * set), and an optional track of transcript calls. Two optional hover modes, "probe
 * marking" and "transcript marking", label the nearest probe or the fitted segments
 * under the mouse cursor.
 * 
 * @author tdanford
 */
public class SegmentViz extends PaintablePanel {
	
	private ModelLocatedValues valuesPaintable, predictedPaintable;
	private ModelRangeValues segPaintable;
	private ModelSegmentValues fittedPaintable;
	private VerticalScalePainter scalePaintable;
	private ModelTranscriptCallPainter callPaintable;

	private Genome genome;
	private GenePainter genePaintable;
	// Region shown in the gene track; null whenever no gene track should be painted.
	private Region region;
	
	private String chrom, strand;
	// Probe locations from the last setData() call; segment/probe indices are resolved against it.
	private Integer[] locs; 
	
	private FileInputData data;
	private InputSegmentation segmentation;
	private ArrayList<Cluster> clusters;
	private ArrayList<TAFit> fits;
	private ArrayList<Call> calls;
	
	// Hover state used by the probe-marking and transcript-marking overlays.
	private Set<Model> nearModel; 
	private Collection<Pair<Rectangle,Model>> selectedSegmentModels;
	private Point modelPoint, mousePoint;
	
	private boolean probeMarking, transcriptMarking;
	
	/**
	 * Creates an empty viewer. The gene track is initialized for the strain parsed from the
	 * hard-coded experiment key "original_s288c_mat_a"; if that lookup fails with a
	 * DatabaseException, the viewer is created without a gene track.
	 */
	public SegmentViz() { 
		super();
		
		probeMarking = false;
		transcriptMarking = false;
		
		data = null;
		segmentation = null;
		clusters = new ArrayList<Cluster>();
		fits = new ArrayList<TAFit>();
		calls = new ArrayList<Call>();
		
		valuesPaintable = new ModelLocatedValues();
		predictedPaintable = new ModelLocatedValues();
		segPaintable = new ModelRangeValues();
		fittedPaintable = new ModelSegmentValues();
		callPaintable = new ModelTranscriptCallPainter();
		
		BaseExpressionProperties bps = new SudeepExpressionProperties();
		String key = "original_s288c_mat_a";
		String strain = bps.parseStrainFromExptKey(key);
		GeneAnnotationProperties gaps = new GeneAnnotationProperties();
		
		try { 
			genome = bps.getGenome(strain);
			genePaintable = new GenePainter(gaps, strain);
		} catch(DatabaseException e) { 
			// Best-effort: the viewer still works without a gene track, but say why it is missing.
			System.err.println("SegmentViz: gene track disabled (" + e.getMessage() + ")");
			genome = null;
			genePaintable = null;
		}
		
		DoubleBufferedPaintable vbuffered = new DoubleBufferedPaintable(valuesPaintable);
		vbuffered.paintBackground(true);
		setPaintable(vbuffered);
		
		String bk = ModelRangeValues.boundsKey;
		
		valuesPaintable.setProperty(ModelLocatedValues.stemKey, false);
		predictedPaintable.setProperty(ModelLocatedValues.stemKey, false);
		
		// Share the [start, end] bounds property with every overlay and track.
		valuesPaintable.synchronizeProperty(bk, segPaintable);
		valuesPaintable.synchronizeProperty(bk, predictedPaintable);
		valuesPaintable.synchronizeProperty(bk, fittedPaintable);
		valuesPaintable.synchronizeProperty(bk, callPaintable);
		
		// Share the value scale with the fitted and predicted overlays.
		valuesPaintable.synchronizeProperty("scale", fittedPaintable);
		valuesPaintable.synchronizeProperty("scale", predictedPaintable);
		fittedPaintable.setProperty(ModelSegmentValues.drawWeightsKey, Boolean.TRUE);

		PaintableScale scale = valuesPaintable.getPropertyValue("scale");
		scalePaintable = new VerticalScalePainter(scale);
		
		predictedPaintable.setProperty(ModelLocatedValues.colorKey, Color.magenta);
		
		addMouseMotionListener(new MouseMotionAdapter() { 
			@Override
			public void mouseMoved(MouseEvent e) { 
				Pair<Point,Set<Model>> pp = 
					valuesPaintable.findNearestDrawnPoint(e.getPoint());
				if(transcriptMarking) { 
					selectedSegmentModels = fittedPaintable.findDrawnRects(e.getPoint());
					repaint();
				}
				if(probeMarking && pp != null) { 
					nearModel = pp.getLast();
					modelPoint = pp.getFirst();
					mousePoint = e.getPoint();
					repaint();
				}
			}
		});
		
		// Any show/hide/move/resize of the component drops the current probe highlight.
		addComponentListener(new ComponentAdapter() {
			@Override
			public void componentHidden(ComponentEvent e) { clearProbeHighlight(); }

			@Override
			public void componentMoved(ComponentEvent e) { clearProbeHighlight(); }

			@Override
			public void componentResized(ComponentEvent e) { clearProbeHighlight(); }

			@Override
			public void componentShown(ComponentEvent e) { clearProbeHighlight(); }
		});
	}
	
	/** Forgets the currently highlighted probe (if any) and schedules a repaint. */
	private void clearProbeHighlight() { 
		nearModel = null;
		modelPoint = null;
		repaint();
	}
	
	/**
	 * Replaces the genome and gene painter used for the gene track. The current region is
	 * reset, so no gene track is painted until the next call to setBounds.
	 * 
	 * @param g  genome used to build regions (may be null to disable the gene track)
	 * @param gp painter for the gene track (may be null to disable the gene track)
	 */
	public void setGenePaintable(Genome g, GenePainter gp) { 
		genome = g;
		genePaintable = gp;
		region = null;
	}
	
	/**
	 * Sets the [start, end] bounds shared by all tracks and, when a genome, gene painter and
	 * strand are available, the stranded region displayed in the gene track.
	 * 
	 * @param chr   chromosome name
	 * @param str   strand; only its first character is used for the gene-track region
	 * @param start start coordinate of the displayed window
	 * @param end   end coordinate of the displayed window
	 */
	@SuppressWarnings("unchecked")
	public void setBounds(String chr, String str, int start, int end) { 
		PropertyValueWrapper<Integer[]> boundsWrapper = 
			(PropertyValueWrapper<Integer[]>)valuesPaintable.getProperty(ModelRangeValues.boundsKey);
		boundsWrapper.setValue(new Integer[] { start, end });
		
		valuesPaintable.setProperty(ModelRangeValues.boundsKey, boundsWrapper);
		predictedPaintable.setProperty(ModelRangeValues.boundsKey, boundsWrapper);
		callPaintable.setProperty(ModelRangeValues.boundsKey, boundsWrapper);
		segPaintable.setProperty(ModelRangeValues.boundsKey, boundsWrapper);
		fittedPaintable.setProperty(ModelRangeValues.boundsKey, boundsWrapper);
		
		chrom = chr;
		strand = str;
		
		if(genome != null && genePaintable != null && strand != null && strand.length() > 0) { 
			region = new StrandedRegion(genome, chrom, start, end, strand.charAt(0));
			genePaintable.setRegion(region);
			System.out.println("Set Region: " + region.toString());
		} else { 
			// Don't leave a region from previous bounds behind: it would describe the wrong window.
			region = null;
		}
	}

	/**
	 * Makes this viewer and {@code sv} share one value scale. The scale property of sv's
	 * value and fitted paintables is synchronized with this viewer's, the shared scale is
	 * widened to cover sv's previous [min, max] range, and sv's vertical scale bar is rebuilt
	 * from it. Both viewers are repainted.
	 * 
	 * @param sv the viewer whose scale should follow this one
	 */
	public void synchronizeScales(SegmentViz sv) { 
		String sk = ModelLocatedValues.scaleKey;
		PaintableScale oldScale = sv.valuesPaintable.getPropertyValue(sk);
		double max = oldScale.getMax(), min = oldScale.getMin();
		
		valuesPaintable.synchronizeProperty(sk, sv.fittedPaintable);
		valuesPaintable.synchronizeProperty(sk, sv.valuesPaintable);
		
		PaintableScale scale = valuesPaintable.getPropertyValue(sk);
		scale.updateScale(min);
		scale.updateScale(max);
		
		sv.scalePaintable = new VerticalScalePainter(scale);
		sv.repaint();
		// Our own scale may have been widened above, so our plot must be redrawn too.
		repaint();
	}
	
	/**
	 * Replaces the displayed probe values. Clears the value, fitted, call, and segment
	 * tracks as well as any segmentation, clusters, fits, and calls derived from earlier data.
	 * 
	 * @param locations probe locations (one per value)
	 * @param values    probe values
	 * @throws IllegalArgumentException if the two arrays differ in length
	 */
	public void setData(Integer[] locations, Double[] values) { 
		if(locations.length != values.length) { 
			throw new IllegalArgumentException(String.format(
					"%d locations but %d values", locations.length, values.length)); 
		}
		
		valuesPaintable.clearModels();
		fittedPaintable.clearModels();
		callPaintable.clearModels();
		segPaintable.clearModels();
		
		locs = locations.clone();
		
		for(int i = 0; i < locations.length; i++) { 
			valuesPaintable.addModel(new Datapoint(locations[i], values[i]));
		}
		
		data = new FileInputData(chrom, strand, locations, new Double[][] { values });
		segmentation = null;
		clusters.clear(); 
		fits.clear();
		calls.clear();
		
		repaint();
	}
	
	/**
	 * Replaces the predicted-value points. {@code values[i]} is placed at the location of
	 * probe {@code indices[i]} from the last setData() call.
	 * 
	 * @param indices probe indices into the current locations (at least one per value)
	 * @param values  predicted values
	 * @throws IllegalArgumentException if there are fewer indices than values
	 */
	public void setPredictedData(Integer[] indices, Double[] values) { 
		// Validate before clearing so a bad call doesn't leave a half-built track behind.
		if(indices.length < values.length) { 
			throw new IllegalArgumentException(String.format(
					"%d indices for %d values", indices.length, values.length));
		}
		
		predictedPaintable.clearModels();
		
		for(int i = 0; i < values.length; i++) { 
			predictedPaintable.addModel(new Datapoint(locs[indices[i]], values[i]));
		}		
		repaint();
	}
	
	/**
	 * Replaces the displayed segment ranges. Segment start/end indices are converted to
	 * locations using the probe locations from the last setData() call.
	 * 
	 * @param segs segments to display
	 */
	public void setSegments(Collection<Segment> segs) { 
		segPaintable.clearModels();
		
		for(Segment s : segs) { 
			SegmentRange r = new SegmentRange(s, locs);
			segPaintable.addModel(r);
		}
		
		repaint();
	}
	
	/**
	public void setArrangement(TranscriptArrangement arr) { 
		callPaintable.clearModels();
		System.out.println("Setting arrangement: " + arr.toString());
		Integer[] breakpoints = arr.cluster.breakpoints();
		
		for(int i = 0; i < arr.calls.length; i++) { 
			int start = arr.cluster.locations[breakpoints[arr.calls[i].start]];
			int end = arr.cluster.locations[breakpoints[arr.calls[i].end]];
			System.out.println("\tCall: " + start + ", " + end);
			callPaintable.addModel(new ModelTranscriptCallPainter.CallModel(start, end));
		}
		repaint();
	}
	**/
	
	/**
	 * Replaces the displayed transcript calls.
	 * <p>
	 * Important -- unlike the former (now commented-out) setArrangement method, this method
	 * assumes that the coordinates of the transcript calls (the 'start' and 'end' fields) have
	 * *already been converted* into real (base) coordinates.
	 * 
	 * @param cs transcript calls, in base coordinates
	 */
	public void setTranscriptCalls(Collection<Call> cs) {
		calls.clear();
		callPaintable.clearModels();
		
		for(Call c : cs) {
			calls.add(c);
			callPaintable.addModel(new ModelTranscriptCallPainter.CallModel(c.start, c.end));
		}
		repaint();
	}
	
	/**
	 * Replaces the fitted segments, builds a new InputSegmentation from them and the current
	 * data, and clears the call track along with any clusters, fits, and calls.
	 * 
	 * @param segs fitted segments; their start/end indices refer to the current probe locations
	 */
	public void setFitted(Iterator<Segment> segs) {
		ArrayList<Segment> segList = new ArrayList<Segment>();
		
		fittedPaintable.clearModels();
		callPaintable.clearModels();

		while(segs.hasNext()) { 
			Segment s = segs.next();
			segList.add(s);
			FittedRange r = new FittedRange(s, locs, s.params, s.segmentType, s.shared);
			fittedPaintable.addModel(r);
		}

		segmentation = new InputSegmentation(data, segList);
		clusters.clear();
		
		// TODO: fix me. Cluster extraction is disabled (see the commented-out call), so no
		// clusters are collected and findBestFit() currently produces no calls. The null
		// guard keeps this method from failing with a NullPointerException in the meantime.
		Iterator<Cluster> cs = null; //segmentation.clusters();
		
		if(cs != null) { 
			while(cs.hasNext()) { 
				clusters.add(cs.next());
			}
		}
		fits.clear();
		calls.clear();
		
		repaint();
	}
	
	/**
	 * Paints the value plot (with the fitted segments and scale bar), the gene track, the call
	 * track, and any active hover labels. The segment and predicted-value tracks are currently
	 * not painted.
	 */
	protected void paintComponent(Graphics g) { 
		int w = getWidth(), h = getHeight();
		g.setColor(Color.white);
		g.fillRect(0, 0, w, h);
		
		int callCount = callPaintable.size();
		boolean showGenes = region != null && genePaintable != null;
		
		// Height of one bottom track: 1/6 of the panel when calls are shown, otherwise 1/4.
		int h4 = (int)Math.floor((double)h / (callCount > 0 ? 6.0 : 4.0));
		// Vertical space reserved under the value plot for the gene and/or call tracks.
		int bottom = callCount > 0 ? h4 * 2 : h4;
		int yarr = h - bottom;
		
		FontMetrics fm = g.getFontMetrics();
		
		super.paintItem(g, 0, 0, w, yarr);
		fittedPaintable.paintItem(g, 0, 0, w, yarr);
		scalePaintable.paintItem(g, 0, 0, w, yarr);
		
		if(showGenes) { 
			genePaintable.paintItem(g, 0, yarr, w, yarr+h4);
		}
		
		if(callCount > 0) { 
			// Calls go below the gene track when there is one, otherwise directly under the plot.
			int callTop = showGenes ? yarr + h4 : yarr;
			callPaintable.paintItem(g, 0, callTop, w, h);
		}
		
		if(probeMarking && nearModel != null && modelPoint != null && mousePoint != null) { 
			String str = nearModel.toString();
			Point p = fillLabelBackground(g, fm, str, mousePoint.x, mousePoint.y, w);

			g.setColor(Color.black);
			g.drawLine(mousePoint.x, mousePoint.y, modelPoint.x, modelPoint.y);
			g.drawString(str, p.x, p.y);
		}

		if(selectedSegmentModels != null && transcriptMarking) { 
			for(Pair<Rectangle,Model> rmp : selectedSegmentModels) { 
				Rectangle r = rmp.getFirst();
				String str = rmp.getLast().toString();
				Point p = fillLabelBackground(g, fm, str, r.x, r.y, w);

				// Outline the segment's rectangle and cross it out.
				g.setColor(Color.black);
				g.drawRect(r.x, r.y, r.width, r.height);
				g.drawLine(r.x, r.y, r.x+r.width, r.y+r.height);
				g.drawLine(r.x, r.y+r.height, r.x+r.width, r.y);
				g.drawString(str, p.x, p.y);
			}
		}
	}
	
	/**
	 * Positions a text label anchored at (anchorX, anchorY) so that it does not run past the
	 * right edge of a panel of the given width or above its top edge, and fills a white box
	 * behind it. Leaves the graphics color set to white.
	 * 
	 * @return the baseline position at which the label text should be drawn
	 */
	private static Point fillLabelBackground(Graphics g, FontMetrics fm, String text, 
			int anchorX, int anchorY, int width) { 
		int textWidth = fm.charsWidth(text.toCharArray(), 0, text.length());
		int x = anchorX - Math.max(0, anchorX + textWidth - width);
		int y = Math.max(anchorY - fm.getDescent(), fm.getHeight());
		
		g.setColor(Color.white);
		g.fillRect(x, y - fm.getAscent(), textWidth, fm.getHeight());
		return new Point(x, y);
	}
	
	/** A single (location, value) point in the value or predicted-value track. */
	public static class Datapoint extends Model { 
		public Integer location;
		public Double value;
		
		public Datapoint(Integer loc, Double val) { 
			location = loc; value = val;
		}
	}
	
	/** A segment converted from probe indices to the corresponding probe locations. */
	public static class SegmentRange extends Model { 
		public Integer start, end;
		
		/**
		 * @param s    segment whose start/end are indices into {@code locs}
		 * @param locs probe locations
		 * @throws IllegalArgumentException if the segment's indices fall outside {@code locs}
		 */
		public SegmentRange(Segment s, Integer[] locs) { 
			if(s.start < 0 || s.end >= locs.length) { 
				throw new IllegalArgumentException(String.format("%d,%d outside of %d array", s.start, s.end, locs.length));
			}
			
			start = locs[s.start]; 
			end = locs[s.end];
			
			System.out.println(String.format("+Segment(%d,%d)", start, end));
		}
	}

	/** A fitted segment: its location range plus the fitted parameters and segment type. */
	public static class FittedRange extends Model { 
		public Integer start, end;
		public Double[] params;
		public Integer type;
		public boolean shared;
		
		/**
		 * @param s    segment whose start/end are indices into {@code locs}
		 * @param locs probe locations
		 * @param p    fitted parameters (copied)
		 * @param t    segment type
		 * @param sh   the segment's 'shared' flag
		 * @throws IllegalArgumentException if the segment's indices fall outside {@code locs}
		 */
		public FittedRange(Segment s, Integer[] locs, Double[] p, Integer t, boolean sh) { 
			if(s.start < 0 || s.end >= locs.length) { 
				throw new IllegalArgumentException(String.format("%d,%d outside of %d array", s.start, s.end, locs.length));
			}
			
			start = locs[s.start]; 
			end = locs[s.end];
			params = p.clone();
			type = t;
			shared = sh;
		}
	}

	/**
	 * Creates an action that asks for a file and saves an anti-aliased PNG snapshot of this
	 * panel to it. The probe highlight is cleared before painting the snapshot.
	 */
	public Action createSnapshotAction() {
		return new AbstractAction("Snapshot...") {
			public void actionPerformed(ActionEvent e) {
				int w = getWidth(), h = getHeight();
				FileChooser chooser = new FileChooser(null);
				File f = chooser.chooseSave();
				if(f == null) { return; }
				
				nearModel = null; 
				modelPoint = null;
				
				BufferedImage im = new BufferedImage(w, h, BufferedImage.TYPE_INT_RGB);
				Graphics2D g2 = im.createGraphics();
				try { 
					g2.setRenderingHints(new RenderingHints(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON));
					paintComponent(g2);
				} finally { 
					// Release the offscreen graphics context even if painting fails.
					g2.dispose();
				}
				
				try {
					ImageIO.write(im, "png", f);
				} catch (IOException e1) {
					e1.printStackTrace();
				}
			} 
		};
	}

	/**
	 * Turns the transcript-marking hover overlay on or off. Turning it off discards the
	 * current selection.
	 */
	public void setTranscriptMarking(boolean value) {
		transcriptMarking = value;
		if(!transcriptMarking) { 
			// Drop the reference instead of clear()ing it: the field starts out null, and the
			// collection came from fittedPaintable and may not be ours to mutate.
			selectedSegmentModels = null;
		}
		repaint();
	}
	
	/**
	 * Turns the probe-marking hover overlay on or off. Turning it off discards the current
	 * highlight.
	 */
	public void setProbeMarking(boolean value) {
		probeMarking = value;
		if(!probeMarking) { 
			nearModel = null;
			modelPoint = mousePoint = null;
		}
		repaint();
	}

	/**
	 * Runs {@code ident} on every cluster collected by setFitted() and displays the resulting
	 * calls, converted to base coordinates through the current probe locations. Any previously
	 * displayed calls are replaced; if no segmentation has been set, the call track is cleared.
	 * Note that setFitted() currently collects no clusters (see the TODO there).
	 * 
	 * @param ident identifier used to fit a transcript arrangement to each cluster
	 */
	public void findBestFit(TranscriptIdentifier ident) {
		ArrayList<Call> callList = new ArrayList<Call>();
		
		if(segmentation != null) { 
			for(Cluster c : clusters) { 
				TAFit fit = ident.identify(c);
				System.out.println(String.format("Cluster %d-%d -> %d segments -> %d calls",
						data.locations[c.segments[0].start], 
						data.locations[c.segments[c.segments.length-1].end],
						c.segments.length, 
						fit.arrangement.calls.length));
				
				for(int i = 0; i < fit.arrangement.calls.length; i++) { 
					Call tc = fit.arrangement.calls[i];
					Double param = fit.params[i];
					// Call start/end index the cluster's segments; 'end' appears to be
					// exclusive (hence end-1) -- TODO confirm against the identifier.
					int si = tc.start, ei = tc.end-1;
					int idx1 = c.segments[si].start, idx2 = c.segments[ei].end;
					int loc1 = data.locations[idx1], loc2 = data.locations[idx2];
					System.out.println(String.format("\t%d-%d, %.3f", 
							loc1, loc2, param));
					callList.add(new Call(loc1, loc2, param));
				}
			}
		}
		
		setTranscriptCalls(callList);
	}
}
